/*
 * Copyright 2015, The Querydsl Team (http://www.querydsl.com/team)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.querydsl.sql;

import com.mysema.commons.lang.CloseableIterator;
import com.querydsl.core.*;
import com.querydsl.core.support.QueryMixin;
import com.querydsl.core.types.*;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.core.types.dsl.SimpleExpression;
import com.querydsl.core.types.dsl.Wildcard;
import com.querydsl.core.util.ResultSetAdapter;
import java.lang.reflect.InvocationTargetException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.jetbrains.annotations.Nullable;

/**
 * {@code AbstractSQLQuery} is the base type for SQL query implementations
 *
 * @param <T> result type
 * @param <Q> concrete subtype
 * @author tiwe
 */
public abstract class AbstractSQLQuery<T, Q extends AbstractSQLQuery<T, Q>>
    extends ProjectableSQLQuery<T, Q> {

  /** Listener-context data key under which the enclosing {@link #fetchResults()} context is stored. */
  protected static final String PARENT_CONTEXT =
      AbstractSQLQuery.class.getName() + "#PARENT_CONTEXT";

  private static final Logger logger = Logger.getLogger(AbstractSQLQuery.class.getName());

  /** Appends a windowed total-count column after the projection (used by {@link #fetchResults()}). */
  private static final QueryFlag rowCountFlag =
      new QueryFlag(QueryFlag.Position.AFTER_PROJECTION, ", count(*) over() ");

  @Nullable private Supplier<Connection> connProvider;

  @Nullable private Connection conn;

  protected SQLListeners listeners;

  protected boolean useLiterals;

  /**
   * When set, {@link #fetch()} reads the column following the projection from the first row into
   * {@link #lastCell} and then clears this flag.
   */
  private boolean getLastCell;

  /** Value captured by {@link #fetch()} while {@link #getLastCell} was set; reset per fetch. */
  private Object lastCell;

  private SQLListenerContext parentContext;

  private StatementOptions statementOptions = StatementOptions.DEFAULT;

  public AbstractSQLQuery(@Nullable Connection conn, Configuration configuration) {
    this(conn, configuration, new DefaultQueryMetadata());
  }

  public AbstractSQLQuery(
      @Nullable Connection conn, Configuration configuration, QueryMetadata metadata) {
    super(new QueryMixin<Q>(metadata, false), configuration);
    this.conn = conn;
    this.listeners = new SQLListeners(configuration.getListeners());
    this.useLiterals = configuration.getUseLiterals();
  }

  public AbstractSQLQuery(Supplier<Connection> connProvider, Configuration configuration) {
    this(connProvider, configuration, new DefaultQueryMetadata());
  }

  public AbstractSQLQuery(
      Supplier<Connection> connProvider, Configuration configuration, QueryMetadata metadata) {
    super(new QueryMixin<Q>(metadata, false), configuration);
    this.connProvider = connProvider;
    this.listeners = new SQLListeners(configuration.getListeners());
    this.useLiterals = configuration.getUseLiterals();
  }

  /**
   * Create an alias for the expression
   *
   * @param alias alias
   * @return this as alias
   */
  public SimpleExpression<T> as(String alias) {
    return Expressions.as(this, alias);
  }

  /**
   * Create an alias for the expression
   *
   * @param alias alias
   * @return this as alias
   */
  @SuppressWarnings("unchecked")
  public SimpleExpression<T> as(Path<?> alias) {
    return Expressions.as(this, (Path) alias);
  }

  /**
   * Add a listener
   *
   * @param listener listener to add
   */
  public void addListener(SQLListener listener) {
    listeners.add(listener);
  }

  /**
   * Execute a count query for the current metadata.
   *
   * @return the count reported by the database, or 0 if the count query returned no row
   */
  @Override
  public long fetchCount() {
    try {
      return unsafeCount();
    } catch (SQLException e) {
      String error = "Caught " + e.getClass().getName();
      logger.log(Level.SEVERE, error, e);
      throw configuration.translate(e);
    }
  }

  /**
   * If you use forUpdate() with a backend that uses page or row locks, rows examined by the query
   * are write-locked until the end of the current transaction.
   *
   * <p>Not supported for SQLite and CUBRID
   *
   * @return the current object
   */
  public Q forUpdate() {
    QueryFlag forUpdateFlag = configuration.getTemplates().getForUpdateFlag();
    return addFlag(forUpdateFlag);
  }

  /**
   * FOR SHARE causes the rows retrieved by the SELECT statement to be locked as though for update.
   *
   * <p>Supported by MySQL, PostgreSQL, SQLServer.
   *
   * @return the current object
   * @throws QueryException if the FOR SHARE is not supported.
   */
  public Q forShare() {
    return forShare(false);
  }

  /**
   * FOR SHARE causes the rows retrieved by the SELECT statement to be locked as though for update.
   *
   * <p>Supported by MySQL, PostgreSQL, SQLServer.
   *
   * @param fallbackToForUpdate if the FOR SHARE is not supported and this parameter is <code>true
   *     </code>, the {@link #forUpdate()} functionality will be used.
   * @return the current object
   * @throws QueryException if the FOR SHARE is not supported and <i>fallbackToForUpdate</i> is set
   *     to <code>false</code>.
   */
  public Q forShare(boolean fallbackToForUpdate) {
    SQLTemplates sqlTemplates = configuration.getTemplates();

    if (sqlTemplates.isForShareSupported()) {
      QueryFlag forShareFlag = sqlTemplates.getForShareFlag();
      return addFlag(forShareFlag);
    }

    if (fallbackToForUpdate) {
      return forUpdate();
    }

    throw new QueryException("Using forShare() is not supported");
  }

  /**
   * Create a serializer for this query, honouring the {@link #setUseLiterals(boolean)} setting.
   *
   * @return a new serializer
   */
  @Override
  protected SQLSerializer createSerializer() {
    SQLSerializer serializer = new SQLSerializer(configuration);
    serializer.setUseLiterals(useLiterals);
    return serializer;
  }

  /** Read column {@code i} via the configuration, passing {@code expr} as path when it is one. */
  @Nullable
  private <U> U get(ResultSet rs, Expression<?> expr, int i, Class<U> type) throws SQLException {
    return configuration.get(rs, expr instanceof Path ? (Path<?>) expr : null, i, type);
  }

  /** Bind parameter {@code i} via the configuration's type mapping for {@code path}. */
  private void set(PreparedStatement stmt, Path<?> path, int i, Object value) throws SQLException {
    configuration.set(stmt, path, i, value);
  }

  /**
   * Called to create and start a new SQL Listener context
   *
   * @param connection the database connection
   * @param metadata the meta data for that context
   * @return the newly started context
   */
  protected SQLListenerContextImpl startContext(Connection connection, QueryMetadata metadata) {
    SQLListenerContextImpl context = new SQLListenerContextImpl(metadata, connection);
    if (parentContext != null) {
      context.setData(PARENT_CONTEXT, parentContext);
    }
    listeners.start(context);
    return context;
  }

  /**
   * Called to make the call back to listeners when an exception happens
   *
   * @param context the current context in play
   * @param e the exception
   */
  protected void onException(SQLListenerContextImpl context, Exception e) {
    context.setException(e);
    listeners.exception(context);
  }

  /**
   * Called to end a SQL listener context
   *
   * @param context the listener context to end
   */
  protected void endContext(SQLListenerContext context) {
    listeners.end(context);
  }

  /**
   * Get the results as a JDBC ResultSet
   *
   * @param exprs the expression arguments to retrieve
   * @return results as ResultSet
   * @deprecated Use @{code select(..)} to define the projection and {@code getResults()} to obtain
   *     the result set
   */
  @Deprecated
  public ResultSet getResults(Expression<?>... exprs) {
    if (exprs.length > 0) {
      queryMixin.setProjection(exprs);
    }
    return getResults();
  }

  /**
   * Get the results as a JDBC ResultSet.
   *
   * <p>The caller owns the returned result set; closing it also closes the underlying statement
   * and ends the listener context. If execution fails, the statement is closed and the context
   * ended before the exception propagates.
   *
   * @return results as ResultSet
   */
  public ResultSet getResults() {
    final SQLListenerContextImpl context = startContext(connection(), queryMixin.getMetadata());
    String queryString = null;
    List<Object> constants = Collections.emptyList();
    PreparedStatement stmt = null;

    try {
      listeners.preRender(context);
      SQLSerializer serializer = serialize(false);
      queryString = serializer.toString();
      logQuery(queryString, serializer.getConstants());
      context.addSQL(getSQL(serializer));
      listeners.rendered(context);

      listeners.notifyQuery(queryMixin.getMetadata());

      constants = serializer.getConstants();

      listeners.prePrepare(context);
      stmt = getPreparedStatement(queryString);
      setParameters(stmt, constants, serializer.getConstantPaths(), getMetadata().getParams());
      context.addPreparedStatement(stmt);
      listeners.prepared(context);

      listeners.preExecute(context);
      final ResultSet rs = stmt.executeQuery();
      listeners.executed(context);

      final PreparedStatement ownedStmt = stmt;
      return new ResultSetAdapter(rs) {
        @Override
        public void close() throws SQLException {
          try {
            super.close();
          } finally {
            // nested so the listener context is ended even if closing the statement fails
            try {
              ownedStmt.close();
            } finally {
              reset();
              endContext(context);
            }
          }
        }
      };
    } catch (SQLException e) {
      onException(context, e);
      closeOnFailure(stmt, e);
      reset();
      endContext(context);
      throw configuration.translate(queryString, constants, e);
    } catch (RuntimeException e) {
      // e.g. ParamNotSetException or a translated binding error from setParameters
      onException(context, e);
      closeOnFailure(stmt, e);
      reset();
      endContext(context);
      throw e;
    }
  }

  /**
   * Prepare {@code queryString} on the current connection and apply the configured statement
   * options. The statement is closed again if applying an option fails.
   */
  private PreparedStatement getPreparedStatement(String queryString) throws SQLException {
    PreparedStatement statement = connection().prepareStatement(queryString);
    try {
      Integer fetchSize = statementOptions.getFetchSize();
      if (fetchSize != null) {
        statement.setFetchSize(fetchSize);
      }
      Integer maxFieldSize = statementOptions.getMaxFieldSize();
      if (maxFieldSize != null) {
        statement.setMaxFieldSize(maxFieldSize);
      }
      Integer queryTimeout = statementOptions.getQueryTimeout();
      if (queryTimeout != null) {
        statement.setQueryTimeout(queryTimeout);
      }
      Integer maxRows = statementOptions.getMaxRows();
      if (maxRows != null) {
        statement.setMaxRows(maxRows);
      }
    } catch (SQLException | RuntimeException e) {
      closeOnFailure(statement, e);
      throw e;
    }
    return statement;
  }

  /**
   * Close {@code stmt} on an error path. A failure to close is attached to {@code cause} as a
   * suppressed exception instead of replacing it.
   *
   * @param stmt statement to close, may be {@code null} if it was never created
   * @param cause the exception that triggered the cleanup
   */
  private static void closeOnFailure(@Nullable PreparedStatement stmt, Exception cause) {
    if (stmt != null) {
      try {
        stmt.close();
      } catch (SQLException closeError) {
        cause.addSuppressed(closeError);
      }
    }
  }

  protected Configuration getConfiguration() {
    return configuration;
  }

  /**
   * Execute the query and iterate lazily over the projected rows.
   *
   * @return iterator over the results; the caller must close it
   */
  @SuppressWarnings("unchecked")
  @Override
  public CloseableIterator<T> iterate() {
    Expression<T> expr = (Expression<T>) queryMixin.getMetadata().getProjection();
    return iterateSingle(queryMixin.getMetadata(), expr);
  }

  /**
   * Execute the query and wrap the open result set in an iterator.
   *
   * <p>On success, statement, result set and listener context are handed to the returned
   * iterator. On failure they are closed/ended here before the exception propagates.
   *
   * @param metadata metadata whose params are bound to the statement
   * @param expr the projection, or {@code null} to read the first column of each row
   */
  private CloseableIterator<T> iterateSingle(
      QueryMetadata metadata, @Nullable final Expression<T> expr) {
    SQLListenerContextImpl context = startContext(connection(), queryMixin.getMetadata());
    String queryString = null;
    List<Object> constants = Collections.emptyList();
    PreparedStatement stmt = null;

    try {
      listeners.preRender(context);
      SQLSerializer serializer = serialize(false);
      queryString = serializer.toString();
      logQuery(queryString, serializer.getConstants());
      context.addSQL(getSQL(serializer));
      listeners.rendered(context);

      listeners.notifyQuery(queryMixin.getMetadata());
      constants = serializer.getConstants();

      listeners.prePrepare(context);
      stmt = getPreparedStatement(queryString);
      setParameters(stmt, constants, serializer.getConstantPaths(), metadata.getParams());
      context.addPreparedStatement(stmt);
      listeners.prepared(context);

      listeners.preExecute(context);
      ResultSet rs = stmt.executeQuery();
      listeners.executed(context);

      return newResultIterator(expr, stmt, rs, context);

    } catch (SQLException e) {
      onException(context, e);
      closeOnFailure(stmt, e);
      endContext(context);
      throw configuration.translate(queryString, constants, e);
    } catch (RuntimeException e) {
      logger.log(Level.SEVERE, "Caught " + e.getClass().getName() + " for " + queryString, e);
      onException(context, e);
      closeOnFailure(stmt, e);
      endContext(context);
      throw e;
    } finally {
      reset();
    }
  }

  /**
   * Create the iterator that converts each row of {@code rs} according to the projection shape:
   * first column when {@code expr} is null, a factory instance, a raw Object[] row for
   * {@link Wildcard#all}, or the single typed column otherwise.
   */
  @SuppressWarnings("unchecked")
  private CloseableIterator<T> newResultIterator(
      @Nullable final Expression<T> expr,
      PreparedStatement stmt,
      ResultSet rs,
      SQLListenerContextImpl context)
      throws SQLException {
    if (expr == null) {
      return new SQLResultIterator<T>(configuration, stmt, rs, listeners, context) {
        @Override
        public T produceNext(ResultSet rs) throws Exception {
          return (T) rs.getObject(1);
        }
      };
    } else if (expr instanceof FactoryExpression) {
      return new SQLResultIterator<T>(configuration, stmt, rs, listeners, context) {
        @Override
        public T produceNext(ResultSet rs) throws Exception {
          return newInstance((FactoryExpression<T>) expr, rs, 0);
        }
      };
    } else if (expr.equals(Wildcard.all)) {
      return new SQLResultIterator<T>(configuration, stmt, rs, listeners, context) {
        @Override
        public T produceNext(ResultSet rs) throws Exception {
          Object[] rv = new Object[rs.getMetaData().getColumnCount()];
          for (int i = 0; i < rv.length; i++) {
            rv[i] = rs.getObject(i + 1);
          }
          return (T) rv;
        }
      };
    } else {
      return new SQLResultIterator<T>(configuration, stmt, rs, listeners, context) {
        @Override
        public T produceNext(ResultSet rs) throws Exception {
          return get(rs, expr, 1, expr.getType());
        }
      };
    }
  }

  /**
   * Execute the query and materialize all rows into a list.
   *
   * <p>When {@link #getLastCell} is set, the column following the projection in the first row is
   * stored in {@link #lastCell} (used by {@link #fetchResults()}).
   *
   * @return the projected rows
   */
  @SuppressWarnings("unchecked")
  @Override
  public List<T> fetch() {
    Expression<T> expr = (Expression<T>) queryMixin.getMetadata().getProjection();
    SQLListenerContextImpl context = startContext(connection(), queryMixin.getMetadata());
    String queryString = null;
    List<Object> constants = Collections.emptyList();

    try {
      listeners.preRender(context);
      SQLSerializer serializer = serialize(false);
      queryString = serializer.toString();
      logQuery(queryString, serializer.getConstants());
      context.addSQL(getSQL(serializer));
      listeners.rendered(context);

      listeners.notifyQuery(queryMixin.getMetadata());
      constants = serializer.getConstants();

      listeners.prePrepare(context);
      try (PreparedStatement stmt = getPreparedStatement(queryString)) {
        setParameters(
            stmt, constants, serializer.getConstantPaths(), queryMixin.getMetadata().getParams());
        context.addPreparedStatement(stmt);
        listeners.prepared(context);

        listeners.preExecute(context);
        try (ResultSet rs = stmt.executeQuery()) {
          listeners.executed(context);
          lastCell = null;
          final List<T> rv = new ArrayList<T>();
          if (expr instanceof FactoryExpression) {
            FactoryExpression<T> fe = (FactoryExpression<T>) expr;
            // the extra column, if requested, directly follows the factory arguments
            int lastCellColumn = fe.getArgs().size() + 1;
            while (rs.next()) {
              if (getLastCell) {
                lastCell = rs.getObject(lastCellColumn);
                getLastCell = false;
              }
              rv.add(newInstance(fe, rs, 0));
            }
          } else if (expr.equals(Wildcard.all)) {
            // the column count is the same for every row, so read the metadata only once
            int columnCount = rs.getMetaData().getColumnCount();
            while (rs.next()) {
              Object[] row = new Object[columnCount];
              if (getLastCell) {
                lastCell = rs.getObject(columnCount);
                getLastCell = false;
              }
              for (int i = 0; i < columnCount; i++) {
                row[i] = rs.getObject(i + 1);
              }
              rv.add((T) row);
            }
          } else {
            while (rs.next()) {
              if (getLastCell) {
                lastCell = rs.getObject(2);
                getLastCell = false;
              }
              rv.add(get(rs, expr, 1, expr.getType()));
            }
          }
          return rv;
        } catch (IllegalAccessException | InstantiationException | InvocationTargetException e) {
          onException(context, e);
          throw new QueryException(e);
        } catch (SQLException e) {
          onException(context, e);
          throw configuration.translate(queryString, constants, e);
        }
      }
    } catch (SQLException e) {
      onException(context, e);
      throw configuration.translate(queryString, constants, e);
    } finally {
      endContext(context);
      reset();
    }
  }

  /**
   * Fetch the results together with the total row count.
   *
   * <p>If the templates support counting via analytics and the query has no GROUP BY, a {@code
   * count(*) over()} column is appended and the total is read from the first row, with a separate
   * count query only when no rows come back. Otherwise a count query runs first and the rows are
   * fetched only if the count is positive. All nested queries share a parent listener context.
   *
   * @return results and total
   * @throws IllegalStateException if the analytic count column is not numeric
   */
  @SuppressWarnings("unchecked")
  @Override
  public QueryResults<T> fetchResults() {
    parentContext = startContext(connection(), queryMixin.getMetadata());
    Expression<T> expr = (Expression<T>) queryMixin.getMetadata().getProjection();
    QueryModifiers originalModifiers = queryMixin.getMetadata().getModifiers();
    try {
      if (configuration.getTemplates().isCountViaAnalytics()
          && queryMixin.getMetadata().getGroupBy().isEmpty()) {
        List<T> results;
        try {
          queryMixin.addFlag(rowCountFlag);
          getLastCell = true;
          results = fetch();
        } finally {
          queryMixin.removeFlag(rowCountFlag);
        }
        long total;
        if (!results.isEmpty()) {
          if (lastCell instanceof Number) {
            total = ((Number) lastCell).longValue();
          } else {
            throw new IllegalStateException("Unsupported lastCell instance " + lastCell);
          }
        } else {
          total = fetchCount();
        }
        return new QueryResults<T>(results, originalModifiers, total);

      } else {
        queryMixin.setProjection(expr);
        long total = fetchCount();
        if (total > 0) {
          return new QueryResults<T>(fetch(), originalModifiers, total);
        } else {
          return QueryResults.emptyResults();
        }
      }

    } finally {
      endContext(parentContext);
      reset();
      getLastCell = false;
      parentContext = null;
    }
  }

  /** Build a factory instance from consecutive columns starting at {@code offset + 1}. */
  private <RT> RT newInstance(FactoryExpression<RT> c, ResultSet rs, int offset)
      throws InstantiationException,
          IllegalAccessException,
          InvocationTargetException,
          SQLException {
    Object[] args = new Object[c.getArgs().size()];
    for (int i = 0; i < args.length; i++) {
      args[i] = get(rs, c.getArgs().get(i), offset + i + 1, c.getArgs().get(i).getType());
    }
    return c.newInstance(args);
  }

  /** Per-execution cleanup hook; currently a no-op. */
  private void reset() {}

  /**
   * Bind the serialized constants to {@code stmt}, resolving {@link ParamExpression} placeholders
   * from {@code params}.
   *
   * @param stmt statement to bind to
   * @param objects constants in parameter order
   * @param constantPaths paths matching {@code objects} position by position
   * @param params values for parameter expressions
   * @throws IllegalArgumentException if {@code objects} and {@code constantPaths} differ in size
   * @throws ParamNotSetException if a parameter expression has no value in {@code params}
   */
  protected void setParameters(
      PreparedStatement stmt,
      List<?> objects,
      List<Path<?>> constantPaths,
      Map<ParamExpression<?>, ?> params) {
    if (objects.size() != constantPaths.size()) {
      throw new IllegalArgumentException(
          "Expected " + objects.size() + " paths, but got " + constantPaths.size());
    }
    for (int i = 0; i < objects.size(); i++) {
      Object o = objects.get(i);
      try {
        if (o instanceof ParamExpression) {
          if (!params.containsKey(o)) {
            throw new ParamNotSetException((ParamExpression<?>) o);
          }
          o = params.get(o);
        }
        set(stmt, constantPaths.get(i), i + 1, o);
      } catch (SQLException e) {
        throw configuration.translate(e);
      }
    }
  }

  /**
   * Run the count query. The listener context is always ended, even if closing the result set or
   * statement fails.
   *
   * @return the first column of the first row, or 0 when there is no row
   * @throws SQLException if closing the result set or statement fails
   */
  private long unsafeCount() throws SQLException {
    SQLListenerContextImpl context = startContext(connection(), getMetadata());
    String queryString = null;
    List<Object> constants = Collections.emptyList();
    PreparedStatement stmt = null;
    ResultSet rs = null;

    try {
      listeners.preRender(context);
      SQLSerializer serializer = serialize(true);
      queryString = serializer.toString();
      logQuery(queryString, serializer.getConstants());
      context.addSQL(getSQL(serializer));
      listeners.rendered(context);

      constants = serializer.getConstants();
      listeners.prePrepare(context);

      stmt = getPreparedStatement(queryString);
      setParameters(stmt, constants, serializer.getConstantPaths(), getMetadata().getParams());

      context.addPreparedStatement(stmt);
      listeners.prepared(context);

      listeners.preExecute(context);
      rs = stmt.executeQuery();
      boolean hasResult = rs.next();
      listeners.executed(context);

      if (hasResult) {
        return rs.getLong(1);
      } else {
        return 0;
      }
    } catch (SQLException e) {
      onException(context, e);
      throw configuration.translate(queryString, constants, e);
    } finally {
      try {
        try {
          if (rs != null) {
            rs.close();
          }
        } finally {
          if (stmt != null) {
            stmt.close();
          }
        }
      } finally {
        // ended unconditionally so a failing close() cannot leave the context open
        endContext(context);
      }
    }
  }

  /**
   * Log the query at FINE level with newlines collapsed to spaces.
   *
   * @param queryString the SQL
   * @param parameters the bound constants (not logged by this implementation)
   */
  protected void logQuery(String queryString, Collection<Object> parameters) {
    if (logger.isLoggable(Level.FINE)) {
      String normalizedQuery = queryString.replace('\n', ' ');
      logger.fine(normalizedQuery);
    }
  }

  /**
   * Return the connection, obtaining and caching it from the provider on first use.
   *
   * @throws IllegalStateException if neither a connection nor a provider was given, or the
   *     provider returned {@code null}
   */
  private Connection connection() {
    if (conn == null) {
      if (connProvider == null) {
        throw new IllegalStateException("No connection provided");
      }
      conn = connProvider.get();
      if (conn == null) {
        throw new IllegalStateException("Connection provider returned null");
      }
    }
    return conn;
  }

  /**
   * Set whether literals are used in SQL strings instead of parameter bindings (default: false)
   *
   * <p>Warning: When literals are used, prepared statement won't have any parameter bindings and
   * also batch statements will only be simulated, but not executed as actual batch statements.
   *
   * @param useLiterals true for literals and false for bindings
   */
  public void setUseLiterals(boolean useLiterals) {
    this.useLiterals = useLiterals;
  }

  /**
   * Copy state from {@code query} into this instance: superclass state, the literal mode, a child
   * listener set and the statement options.
   *
   * @param query the query to copy from
   */
  @Override
  protected void clone(Q query) {
    super.clone(query);
    this.useLiterals = query.useLiterals;
    this.listeners = new SQLListeners(query.listeners);
    // private members are not accessible through the type variable Q, hence the upcast
    this.statementOptions = ((AbstractSQLQuery<?, ?>) query).statementOptions;
  }

  /**
   * Clone this query using the same connection.
   *
   * @return the cloned query
   */
  @Override
  public Q clone() {
    return this.clone(this.conn);
  }

  public abstract Q clone(Connection connection);

  /**
   * Set the options to be applied to the JDBC statements of this query
   *
   * @param statementOptions options to be applied to statements
   */
  public void setStatementOptions(StatementOptions statementOptions) {
    this.statementOptions = statementOptions;
  }
}
